{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Predicting House Prices with a Neural Network\n",
    "In this project, we build a neural network to predict the sale price of a house."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, we import the libraries we need."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "%matplotlib inline\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load the training and test data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "train = pd.read_csv('./all/train.csv')\n",
    "test = pd.read_csv('./all/test.csv')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
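    "Let's look at the first 5 training samples. Apart from the `Id` column, the columns describe attributes of the house, and the last column, `SalePrice`, is its sale price."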
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Id</th>\n",
       "      <th>MSSubClass</th>\n",
       "      <th>MSZoning</th>\n",
       "      <th>LotFrontage</th>\n",
       "      <th>LotArea</th>\n",
       "      <th>Street</th>\n",
       "      <th>Alley</th>\n",
       "      <th>LotShape</th>\n",
       "      <th>LandContour</th>\n",
       "      <th>Utilities</th>\n",
       "      <th>...</th>\n",
       "      <th>PoolArea</th>\n",
       "      <th>PoolQC</th>\n",
       "      <th>Fence</th>\n",
       "      <th>MiscFeature</th>\n",
       "      <th>MiscVal</th>\n",
       "      <th>MoSold</th>\n",
       "      <th>YrSold</th>\n",
       "      <th>SaleType</th>\n",
       "      <th>SaleCondition</th>\n",
       "      <th>SalePrice</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>60</td>\n",
       "      <td>RL</td>\n",
       "      <td>65.0</td>\n",
       "      <td>8450</td>\n",
       "      <td>Pave</td>\n",
       "      <td>NaN</td>\n",
       "      <td>Reg</td>\n",
       "      <td>Lvl</td>\n",
       "      <td>AllPub</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "      <td>2008</td>\n",
       "      <td>WD</td>\n",
       "      <td>Normal</td>\n",
       "      <td>208500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2</td>\n",
       "      <td>20</td>\n",
       "      <td>RL</td>\n",
       "      <td>80.0</td>\n",
       "      <td>9600</td>\n",
       "      <td>Pave</td>\n",
       "      <td>NaN</td>\n",
       "      <td>Reg</td>\n",
       "      <td>Lvl</td>\n",
       "      <td>AllPub</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0</td>\n",
       "      <td>5</td>\n",
       "      <td>2007</td>\n",
       "      <td>WD</td>\n",
       "      <td>Normal</td>\n",
       "      <td>181500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>3</td>\n",
       "      <td>60</td>\n",
       "      <td>RL</td>\n",
       "      <td>68.0</td>\n",
       "      <td>11250</td>\n",
       "      <td>Pave</td>\n",
       "      <td>NaN</td>\n",
       "      <td>IR1</td>\n",
       "      <td>Lvl</td>\n",
       "      <td>AllPub</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0</td>\n",
       "      <td>9</td>\n",
       "      <td>2008</td>\n",
       "      <td>WD</td>\n",
       "      <td>Normal</td>\n",
       "      <td>223500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4</td>\n",
       "      <td>70</td>\n",
       "      <td>RL</td>\n",
       "      <td>60.0</td>\n",
       "      <td>9550</td>\n",
       "      <td>Pave</td>\n",
       "      <td>NaN</td>\n",
       "      <td>IR1</td>\n",
       "      <td>Lvl</td>\n",
       "      <td>AllPub</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "      <td>2006</td>\n",
       "      <td>WD</td>\n",
       "      <td>Abnorml</td>\n",
       "      <td>140000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>5</td>\n",
       "      <td>60</td>\n",
       "      <td>RL</td>\n",
       "      <td>84.0</td>\n",
       "      <td>14260</td>\n",
       "      <td>Pave</td>\n",
       "      <td>NaN</td>\n",
       "      <td>IR1</td>\n",
       "      <td>Lvl</td>\n",
       "      <td>AllPub</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0</td>\n",
       "      <td>12</td>\n",
       "      <td>2008</td>\n",
       "      <td>WD</td>\n",
       "      <td>Normal</td>\n",
       "      <td>250000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 81 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "   Id  MSSubClass MSZoning  LotFrontage  LotArea Street Alley LotShape  \\\n",
       "0   1          60       RL         65.0     8450   Pave   NaN      Reg   \n",
       "1   2          20       RL         80.0     9600   Pave   NaN      Reg   \n",
       "2   3          60       RL         68.0    11250   Pave   NaN      IR1   \n",
       "3   4          70       RL         60.0     9550   Pave   NaN      IR1   \n",
       "4   5          60       RL         84.0    14260   Pave   NaN      IR1   \n",
       "\n",
       "  LandContour Utilities    ...     PoolArea PoolQC Fence MiscFeature MiscVal  \\\n",
       "0         Lvl    AllPub    ...            0    NaN   NaN         NaN       0   \n",
       "1         Lvl    AllPub    ...            0    NaN   NaN         NaN       0   \n",
       "2         Lvl    AllPub    ...            0    NaN   NaN         NaN       0   \n",
       "3         Lvl    AllPub    ...            0    NaN   NaN         NaN       0   \n",
       "4         Lvl    AllPub    ...            0    NaN   NaN         NaN       0   \n",
       "\n",
       "  MoSold YrSold  SaleType  SaleCondition  SalePrice  \n",
       "0      2   2008        WD         Normal     208500  \n",
       "1      5   2007        WD         Normal     181500  \n",
       "2      9   2008        WD         Normal     223500  \n",
       "3      2   2006        WD        Abnorml     140000  \n",
       "4     12   2008        WD         Normal     250000  \n",
       "\n",
       "[5 rows x 81 columns]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, let's check how many samples the training and test sets contain."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "There are 1460 training samples\n",
      "There are 1459 test samples\n"
     ]
    }
   ],
   "source": [
    "print('There are {} training samples'.format(train.shape[0]))\n",
    "print('There are {} test samples'.format(test.shape[0]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
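    "Next we start processing the data. We take the columns from the **second** one (`MSSubClass`) through `SaleCondition`, the **second-to-last** column of the training set, as the input features of the network. This drops `Id` and the label `SalePrice`. The training and test features are concatenated so that both go through exactly the same preprocessing."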
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_features = pd.concat((train.loc[:, 'MSSubClass':'SaleCondition'],\n",
    "                          test.loc[:, 'MSSubClass':'SaleCondition']))"
   ]
  },
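  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Why concatenate the training and test features before encoding? A small sketch on made-up toy data: if `pd.get_dummies` were applied to each set separately, a category that appears in only one of them (here the hypothetical value `FV`, present only in the test set) would give the two sets different columns. Encoding the concatenated frame and then slicing it back by position keeps the columns aligned.\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "# Toy frames (hypothetical data): 'FV' appears only in the test set.\n",
    "train_toy = pd.DataFrame({'MSZoning': ['RL', 'RM', 'RL']})\n",
    "test_toy = pd.DataFrame({'MSZoning': ['RL', 'FV']})\n",
    "\n",
    "# Encoding each set on its own gives mismatched column sets.\n",
    "separate_train = pd.get_dummies(train_toy)\n",
    "separate_test = pd.get_dummies(test_toy)\n",
    "print(list(separate_train.columns), list(separate_test.columns))\n",
    "\n",
    "# Encoding the concatenation, then splitting by position, keeps them aligned.\n",
    "encoded = pd.get_dummies(pd.concat((train_toy, test_toy)))\n",
    "n_train = train_toy.shape[0]\n",
    "train_part = encoded.iloc[:n_train]\n",
    "test_part = encoded.iloc[n_train:]\n",
    "print(list(train_part.columns) == list(test_part.columns))\n",
    "```"
   ]
  },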
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next we standardize the data: for every numeric feature, we subtract its mean and divide by its standard deviation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "numeric_feats = all_features.dtypes[all_features.dtypes != \"object\"].index  # select all numeric features\n",
    "\n",
    "# Subtract the mean, divide by the standard deviation\n",
    "all_features[numeric_feats] = all_features[numeric_feats].apply(\n",
    "    lambda x: (x - x.mean()) / x.std())"
   ]
  },
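  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick check of the standardization on a made-up column (values loosely modeled on `LotArea` above). After the transform, the non-missing values have mean 0 and standard deviation 1. Missing values stay NaN and are filled in later. pandas skips NaN in `mean()` and `std()`, and `std()` uses the sample standard deviation (`ddof=1`).\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "# Hypothetical LotArea-like column with one missing value.\n",
    "x = pd.Series([8450.0, 9600.0, 11250.0, np.nan, 14260.0])\n",
    "\n",
    "# Same transform as above: subtract the mean, divide by the standard deviation.\n",
    "z = (x - x.mean()) / x.std()\n",
    "\n",
    "print(z.round(3).tolist())\n",
    "print(z.mean(), z.std())  # approximately 0 and 1\n",
    "```"
   ]
  },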
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Take the log of the target price (SalePrice)\n",
    "train['SalePrice'] = np.log(train['SalePrice'])"
   ]
  },
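  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Why work with log prices? A small sketch with made-up prices. In log space, the same *relative* error costs the same for a cheap house and an expensive one. In raw dollars, the expensive house dominates.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "true_prices = np.array([100000.0, 500000.0])  # a cheap and an expensive house\n",
    "pred_prices = true_prices * 1.1               # both over-predicted by 10%\n",
    "\n",
    "raw_errors = pred_prices - true_prices                  # errors in dollars\n",
    "log_errors = np.log(pred_prices) - np.log(true_prices)  # errors in log space\n",
    "\n",
    "print(raw_errors)  # the expensive house's error is 5 times larger\n",
    "print(log_errors)  # both equal log(1.1), about 0.0953\n",
    "```"
   ]
  },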
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
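    "If you look closely at the features above, you will notice that besides the numeric features there are many non-numeric (categorical) features. These have no natural numeric value, so we use the pandas built-in function `pd.get_dummies` to turn each of them into one-hot indicator columns.\n",
    "\n",
    "For example, suppose for simplicity that **MSZoning** takes only two values, RL and RM. We then replace this feature with two new features, RL and RM. If a sample's **MSZoning** is RL, then RL is 1 and RM is 0. Conversely, if it is RM, then RL is 0 and RM is 1. With `dummy_na=True`, an extra indicator column is also created for missing values.\n",
    "\n",
    "| MSZoning | RL | RM |\n",
    "|-|-|-|\n",
    "| RM | 0 | 1 |\n",
    "| RL | 1 | 0 |"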
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_features = pd.get_dummies(all_features, dummy_na=True)"
   ]
  },
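  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "What `pd.get_dummies(..., dummy_na=True)` produces, shown on a made-up toy frame. The categorical column is replaced by one indicator column per category, plus a `_nan` column that flags missing values. Numeric columns pass through unchanged. Recent pandas versions return boolean indicator columns, but they are cast to `float32` later anyway.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "df = pd.DataFrame({\n",
    "    'MSZoning': ['RL', 'RM', np.nan],  # categorical, with one missing value\n",
    "    'LotArea': [8450, 9600, 11250],    # numeric, left unchanged\n",
    "})\n",
    "\n",
    "encoded = pd.get_dummies(df, dummy_na=True)\n",
    "print(list(encoded.columns))\n",
    "print(encoded.astype(int))\n",
    "```"
   ]
  },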
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
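    "In addition, some entries in the data are missing (pandas shows them as NaN). The network cannot take missing values as input, so they have to be filled in. Here we replace each missing value with the mean of its column."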
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_features = all_features.fillna(all_features.mean())\n",
    "feat_dim = all_features.shape[1]"
   ]
  },
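  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick check of the mean imputation on a made-up column, standardized the same way as above. `mean()` and `std()` both skip NaN, so every standardized numeric column has mean 0 (up to rounding). Filling with the column mean therefore amounts to filling missing numeric values with 0.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "# Toy LotFrontage-like column (made-up values) with one missing entry.\n",
    "df = pd.DataFrame({'LotFrontage': [65.0, 80.0, np.nan, 60.0, 84.0]})\n",
    "df = df.apply(lambda x: (x - x.mean()) / x.std())  # standardize as above\n",
    "\n",
    "filled = df.fillna(df.mean())\n",
    "print(filled['LotFrontage'].round(3).tolist())\n",
    "```"
   ]
  },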
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
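    "The preprocessing is now done. Next we extract the training set, the validation set, and the full training data (training plus validation) as numpy arrays and convert them to PyTorch tensors."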
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_train = int(0.9 * train.shape[0])  # 90% of the samples for training, 10% for validation\n",
    "indices = np.arange(train.shape[0])\n",
    "np.random.shuffle(indices)  # shuffle the sample order\n",
    "train_indices = indices[:num_train]\n",
    "valid_indices = indices[num_train:]\n",
    "\n",
    "# Extract the features of the training and validation sets\n",
    "train_features = all_features.iloc[train_indices].values.astype(np.float32)\n",
    "train_features = torch.from_numpy(train_features)\n",
    "valid_features = all_features.iloc[valid_indices].values.astype(np.float32)\n",
    "valid_features = torch.from_numpy(valid_features)\n",
    "train_valid_features = all_features.iloc[:train.shape[0]].values.astype(np.float32)  # all training rows\n",
    "train_valid_features = torch.from_numpy(train_valid_features)\n",
    "\n",
    "# Extract the labels of the training and validation sets\n",
    "train_labels = train['SalePrice'].values[train_indices, None].astype(np.float32)\n",
    "train_labels = torch.from_numpy(train_labels)\n",
    "valid_labels = train['SalePrice'].values[valid_indices, None].astype(np.float32)\n",
    "valid_labels = torch.from_numpy(valid_labels)\n",
    "train_valid_labels = train['SalePrice'].values[:, None].astype(np.float32)\n",
    "train_valid_labels = torch.from_numpy(train_valid_labels)"
   ]
  },
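  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The 90/10 split above can be sanity-checked on a tiny index array. Shuffling and cutting at 90% gives disjoint training and validation indices that together cover every sample. The seeded generator is an addition not in the original cell. Calling `np.random.seed(...)` before `np.random.shuffle` would likewise make the notebook's split reproducible.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "n_samples = 20\n",
    "num_train = int(0.9 * n_samples)   # 18 training, 2 validation\n",
    "\n",
    "rng = np.random.default_rng(0)     # arbitrary seed, for reproducibility\n",
    "indices = np.arange(n_samples)\n",
    "rng.shuffle(indices)               # in-place shuffle\n",
    "\n",
    "train_indices = indices[:num_train]\n",
    "valid_indices = indices[num_train:]\n",
    "\n",
    "print(len(train_indices), len(valid_indices))   # 18 2\n",
    "print(set(train_indices) & set(valid_indices))  # set(): no overlap\n",
    "```"
   ]
  },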
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_features = all_features.iloc[train.shape[0]:].values.astype(np.float32)\n",
    "test_features = torch.from_numpy(test_features)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
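    "This is where the neural network is defined, and you can build any network you like. The model below is the simplest choice, a single linear layer."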
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sequential(\n",
      "  (0): Linear(in_features=331, out_features=1, bias=True)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "def get_model():\n",
    "    net = nn.Sequential(\n",
    "        nn.Linear(feat_dim, 1)\n",
    "    )\n",
    "    return net\n",
    "\n",
    "net = get_model()\n",
    "print(net)"
   ]
  },
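  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since any architecture can be used here, one possible alternative is a small multilayer perceptron. The function name `get_mlp_model`, the hidden width of 64, the ReLU activation and the dropout rate are arbitrary illustrative choices, not tuned values. 331 is the input dimension printed above. To try it, `get_model` could simply return `get_mlp_model(feat_dim)`.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "def get_mlp_model(in_dim, hidden=64, dropout=0.1):\n",
    "    # Hypothetical two-layer alternative to the single nn.Linear model above.\n",
    "    return nn.Sequential(\n",
    "        nn.Linear(in_dim, hidden),\n",
    "        nn.ReLU(),\n",
    "        nn.Dropout(dropout),\n",
    "        nn.Linear(hidden, 1),   # one output: the predicted log price\n",
    "    )\n",
    "\n",
    "mlp = get_mlp_model(331)\n",
    "out = mlp(torch.randn(4, 331))  # a fake batch of 4 samples\n",
    "print(out.shape)  # torch.Size([4, 1])\n",
    "```"
   ]
  },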
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
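    "When evaluating the model, we want expensive and cheap houses to influence the score roughly equally, so we do not use the plain mean squared error as the final metric. Instead, we take the log of both the predicted and the true prices and use the square root of the mean squared error between them:\n",
    "\n",
    "$$\\sqrt{\\frac{1}{n}\\sum_{i=1}^{n}\\left(\\log \\hat{y}_i - \\log y_i\\right)^2}$$\n",
    "\n",
    "Here $\\hat{y}_i$ is the predicted price and $y_i$ the true price of house $i$. This metric is already implemented in `utils.py`. Take a look there if you are interested."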
   ]
  },
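  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal NumPy sketch of the metric described above. It is written from the description only and is not the code in `utils.py`. Clipping predictions to at least 1 before taking the log is an added safeguard (an assumption). Note also that `SalePrice` was already log-transformed earlier, so an RMSE computed directly between the network's outputs and those labels measures the same quantity.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def log_rmse(pred_prices, true_prices):\n",
    "    # Clip so that log() is defined even for non-positive predictions (assumed safeguard).\n",
    "    pred_prices = np.clip(pred_prices, 1.0, None)\n",
    "    log_diff = np.log(pred_prices) - np.log(true_prices)\n",
    "    return float(np.sqrt(np.mean(log_diff ** 2)))\n",
    "\n",
    "true = np.array([208500.0, 181500.0, 223500.0])  # first three SalePrice values shown above\n",
    "pred = true * np.exp(0.1)                         # every prediction about 10.5% too high\n",
    "print(round(log_rmse(pred, true), 6))  # 0.1\n",
    "```"
   ]
  },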
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import train_model, pred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 10, train rmse: 0.176, valid rmse: 0.164\n",
      "\n",
      "epoch: 20, train rmse: 0.142, valid rmse: 0.145\n",
      "\n",
      "epoch: 30, train rmse: 0.125, valid rmse: 0.138\n",
      "\n",
      "epoch: 40, train rmse: 0.115, valid rmse: 0.132\n",
      "\n",
      "epoch: 50, train rmse: 0.109, valid rmse: 0.130\n",
      "\n",
      "epoch: 60, train rmse: 0.105, valid rmse: 0.123\n",
      "\n",
      "epoch: 70, train rmse: 0.102, valid rmse: 0.122\n",
      "\n",
      "epoch: 80, train rmse: 0.101, valid rmse: 0.119\n",
      "\n",
      "epoch: 90, train rmse: 0.099, valid rmse: 0.118\n",
      "\n",
      "epoch: 100, train rmse: 0.098, valid rmse: 0.116\n",
      "\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAlsAAAFACAYAAACLPLm0AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAIABJREFUeJzt3X2UW/V95/HPVxrNjDQe+XF4spOYNClxIGBgYEkhhELaBUJI0vCUhpRl0+NuSgukybZkNz2btukpPe22aQ6UlAItaQg0NWHJ5kBIk0BolofEJgbMQ0ogpjZPHow9mrE0D5K++8fVzMhjjT2256c7vvf9OkdnpKur+/tIGs18dHV1r7m7AAAAEEYm7gAAAABJRtkCAAAIiLIFAAAQEGULAAAgIMoWAABAQJQtAACAgChbAAAAAVG2AAAAAqJsAQAABNQRd4Bmy5Yt85UrV8YdAwAAYK/Wr1//urv37W2+eVW2Vq5cqXXr1sUdAwAAYK/M7MXZzMfHiAAAAAFRtgAAAAKibAEAAAQ0r7bZAgAA89/4+Li2bNmikZGRuKO0RXd3t1asWKFcLrdft6dsAQCAfbJlyxb19vZq5cqVMrO44wTl7tq2bZu2bNmiI488cr+WwceIAABgn4yMjGjp0qWJL1qSZGZaunTpAa3Fo2wBAIB9loaiNeFA7ytlCwAAICDKFgAAOOjs2LFDf/u3f7vPtzv33HO1Y8eOAIlmlqqy9Y1vSPfdF3cKAABwoGYqW9VqdY+3u+eee7Ro0aJQsVpK1bcR/+Sy5/SmRcP6z5uPjzsKAAA4ANdcc42ef/55rV69WrlcTt3d3Vq8eLGeffZZ/fu//7s+9KEPafPmzRoZGdFVV12lNWvWSJo6NODw8LDOOeccnXbaaXrooYe0fPly3X333crn83OeNVVlK+8VVSpxpwAAIEGuvlrasGFul7l6tfTFL+5xlmuvvVYbN27Uhg0b9MADD+j973+/Nm7cOLl7hltuuUVLlixRpVLRSSedpI985CNaunTpLst47rnndPvtt+vv//7vddFFF+nOO+/UpZdeOrf3RSkrW4XcmMrjXXHHAAAAc+zkk0/eZT9YX/rSl3TXXXdJkjZv3qznnntut7J15JFHavXq1ZKkE088UZs2bQqSLVVlK5+r6Y2d+7f3VwAA0MJe1kC1S09Pz+T5Bx54QN/97nf18MMPq1Ao6Iwzzmi5n6yurqkVMNlsVpVAH3+lagP5QmdV5SprtgAAONj19vZqaGio5XWDg4NavHixCoWCnn32WT3yyCNtTrerVK3ZKnTVVK5RtgAAONgtXbpUp556qo455hjl83kdeuihk9edffbZ+vKXv6xVq1bpqKOO0imnnBJj0pSVrXxXTRWnbAEAkARf+9rXWk7v6urSvffe2/K6ie2yli1bpo0bN05O/8xnPjPn+Sak62PEble53h13DAAAkCKpKlv5vFRRXu5xJwEAAGmRqrJVyLtcGY3u3PPeZQEAAOZKuspWT3TU7vI29mwKAADaI1VlK1+IylZlx2jMSQAAQFqkqmwVeqO7W35j9x2bAQAAhBCsbJnZUWa2oelUMrOrQ403G/kF0Z4uWLMFAEC6LFiwQJL08ssv64ILLmg5zxlnnKF169bN+djB9rPl7j+VtFqSzCwr6SVJd4UabzYKvVlJUnlwPM4YAAAgJkcccYTWrl3b1jHb9THiWZKed/cX2zReS4ViY83W4FicMQAAwAG65pprdP31109e/vznP68vfOELOuuss3TCCSfoXe96l+6+++7dbrdp0yYdc8wxkqRKpaJLLrlEq1at0oc//OFgx0Zs1x7kL5F0e6srzGyNpDWS9OY3vzloiHwxOgh1ucSuHwAAmAtXXy1t2DC3y1y9eu/Ht7744ot19dVX64orrpAkff3rX9d9992nK6+8UsViUa+//rpOOeUUnX/++TKzlsu44YYbVCgU
9Mwzz+iJJ57QCSecMLd3pCF42TKzTknnS/psq+vd/UZJN0pSf39/0N2NFhZ1SqJsAQBwsDv++OO1detWvfzyyxoYGNDixYt12GGH6VOf+pQefPBBZTIZvfTSS3rttdd02GGHtVzGgw8+qCuvvFKSdOyxx+rYY48NkrUda7bOkfSYu7/WhrH2KL8oOi5iZbgWcxIAAJJhb2ugQrrwwgu1du1avfrqq7r44ot12223aWBgQOvXr1cul9PKlSs1MhL/Hgjasc3WRzXDR4jtNrlma7gecxIAAHCgLr74Yt1xxx1au3atLrzwQg0ODuqQQw5RLpfT/fffrxdf3POm4qeffvrkwaw3btyoJ554IkjOoGu2zKxH0q9I+q2Q48xWYWleklTZSdkCAOBgd/TRR2toaEjLly/X4Ycfro997GP6wAc+oHe9613q7+/XO97xjj3e/pOf/KQuv/xyrVq1SqtWrdKJJ54YJGfQsuXuOyUtDTnGvsgv7pYklcsciRoAgCR48sknJ88vW7ZMDz/8cMv5hoeHJUkrV67Uxo0bJUn5fF533HFH8Iyp2oN8rphXVlWVd8adBAAApEWqypZlM8qrospI66+AAgAAzLVUlS1JKlhF5QplCwCAA+Genk1yDvS+pq5s5TOjqoxStgAA2F/d3d3atm1bKgqXu2vbtm3q7u7e72W0aw/y80YhM6ryaOruNgAAc2bFihXasmWLBgYG4o7SFt3d3VqxYsV+3z51raPQMaryWDbuGAAAHLRyuZyOPPLIuGMcNNL3MWLHuCpjqeuYAAAgJqkrW4WOcZWrubhjAACAlEhd2cp3VlWhbAEAgDZJXdkqdNZUrnbFHQMAAKRE+spWV03lOmULAAC0R+rKVr6rrgplCwAAtEnqylYh7yp7Pu4YAAAgJVJXtvJ5aVTdqteSv9dbAAAQv9SVrUIhKlmV0njMSQAAQBqksGxFx0Usb6vEnAQAAKRB6spWvie6y5XtIzEnAQAAaZC6slVYEN3l8vbRmJMAAIA0SF3ZyvdGx0Ws7KBsAQCA8FJXtgq9WUlSeZAN5AEAQHjpK1sLo+MiUrYAAEA7pK5s5YtR2aoMVWNOAgAA0iB1ZauwqFOSVC5RtgAAQHipK1v5RdFxESvDtZiTAACANAhatsxskZmtNbNnzewZM3t3yPFmo7A4Klvl4XrMSQAAQBp0BF7+30j6trtfYGadkgqBx9urwpJuSVJ5J8dGBAAA4QUrW2a2UNLpkv6LJLn7mKSxUOPNVn5JXpJUKVO2AABAeCE/RjxS0oCkfzCzn5jZTWbWE3C8WckWutSpUZU5NCIAAGiDkGWrQ9IJkm5w9+Ml7ZR0zfSZzGyNma0zs3UDAwMB40wOqLwqqlQs/FgAACD1QpatLZK2uPujjctrFZWvXbj7je7e7+79fX19AeNMKWRGVB6hbAEAgPCClS13f1XSZjM7qjHpLElPhxpvX0RlKxt3DAAAkAKhv434u5Jua3wT8QVJlwceb1bymTFVxihbAAAgvKBly903SOoPOcb+KHSMqkzZAgAAbZC6PchLUr5jXJXx0Cv1AAAAUlq2CrmqyuOdcccAAAApkM6y1VlVpZaLOwYAAEiBVJatfGdN5VpX3DEAAEAKpLJsFbooWwAAoD1SWbby3XVVvDvuGAAAIAVSWbYKeVfZ83HHAAAAKZDKspXPm6rKaXzM444CAAASLpVlq1CIfla2j8QbBAAAJF46y1ZPdBDq8rZKzEkAAEDSpbJs5Rtlq7JjNOYkAAAg6VJZtgoLouMilrdTtgAAQFipLFv53ui4iKzZAgAAoaWybBV6G2u2BsdjTgIAAJIunWVrYXRcRMoWAAAILZVlK98oW5XhWsxJAABA0qWybBUWRcdFLJeqMScBAABJl8qylW+ULdZsAQCA0FJZtgqLG2u2dnK4HgAAEFY6y9bS6CDUlC0AABBaKstW9+KobFXKlC0AABBWKsuWdeaUV1llDo0IAAACS2XZkqS8RlSp
WNwxAABAwqW2bBUyFZVHUnv3AQBAm6S2bRSyoyqPZuOOAQAAEq4j5MLNbJOkIUk1SVV37w853r7IZ8ZUGaNsAQCAsIKWrYZfdvfX2zDOPil0jKk81o67DwAA0iy1HyPmO8ZVGadsAQCAsEKXLZf0HTNbb2ZrWs1gZmvMbJ2ZrRsYGAgcZ0qhc1zlaq5t4wEAgHQKXbZOc/cTJJ0j6QozO336DO5+o7v3u3t/X19f4DhTCp01latdbRsPAACkU9Cy5e4vNX5ulXSXpJNDjrcv8p01VeqdcccAAAAJF6xsmVmPmfVOnJf0q5I2hhpvXxW6ayrXu+OOAQAAEi7kFuKHSrrLzCbG+Zq7fzvgePsk3+Wq1PkYEQAAhBWsbLn7C5KOC7X8A1XIu8oqyF0yjtoDAAACSe2uHwoFyZXRaLkWdxQAAJBgqS1b+UK0OqvyRiXmJAAAIMlSW7YKPVHZKr8xEnMSAACQZKktW/me6K5XdozGnAQAACRZastWoTc6CHV5O2ULAACEk96yVYy+iFneMRZzEgAAkGSpLVv53qhsVUrjMScBAABJltqyVVgYHYS6PEjZAgAA4aS2bOWLUdmqDFdjTgIAAJIstWWrsDg6VE+5xE5NAQBAOKktW/lFUdmq7KRsAQCAcFJbtgpLuiVJ5WGPOQkAAEiy9JatpXlJUrlM2QIAAOGktmzlinllVFOlHHcSAACQZKktW5bNqKCyyhWLOwoAAEiw1JYtScrbiCochxoAAASU6rJVyIyoPJKNOwYAAEiwlJetUZVHKVsAACCcVJetfHZMlTHKFgAACCfVZauQG1N5rCPuGAAAIMFSXbbyHeOqVClbAAAgnFSXrUJnVeVqZ9wxAABAgqW8bNVUrnbFHQMAACRYqstWvrOmSp01WwAAIJzgZcvMsmb2EzP7Vuix9lWhu65yvTvuGAAAIMHasWbrKknPtGGcfZbvdlWcsgUAAMIJWrbMbIWk90u6KeQ4+6uQd40or3rN444CAAASKvSarS9K+n1J9cDj7JdCT/SzUhqPNwgAAEisWZUtM7vKzIoWudnMHjOzX93Lbc6TtNXd1+9lvjVmts7M1g0MDOxD9AOXz5skqbKdo1EDAIAwZrtm67+6e0nSr0paLOnjkq7dy21OlXS+mW2SdIekM83sq9Nncvcb3b3f3fv7+vpmn3wOFBZEd7+8rdLWcQEAQHrMtmxZ4+e5kv7J3Z9qmtaSu3/W3Ve4+0pJl0j6vrtfut9JA8j3RHe/MjgWcxIAAJBUsy1b683sO4rK1n1m1qt5uh3Wvij0RgehLm8fjTkJAABIqtkeGPATklZLesHdy2a2RNLlsx3E3R+Q9MA+pwusUIzufnkHa7YAAEAYs12z9W5JP3X3HWZ2qaTPSRoMF6s98r1R2eLbiAAAIJTZlq0bJJXN7DhJn5b0vKSvBEvVJoVF0aF6yqVqzEkAAEBSzbZsVd3dJX1Q0nXufr2k3nCx2iNfzEmSKsOULQAAEMZst9kaMrPPKtrlw3vMLCMpFy5WexQWd0mSykO1mJMAAICkmu2arYsljSra39arklZI+otgqdpkqmwd9F+sBAAA89SsylajYN0maWFjz/Aj7n7Qb7OVXxwdhLpS5tiIAAAgjNkeruciST+SdKGkiyQ9amYXhAzWDvmlBUlSeSdlCwAAhDHbbbb+p6ST3H2rJJlZn6TvSlobKlg7dPR0KacxVThaDwAACGS222xlJopWw7Z9uO38ZaaCyipX9njkIQAAgP022zVb3zaz+yTd3rh8saR7wkRqr4KNqDxC2QIAAGHMqmy5+383s49IOrUx6UZ3vytcrPbJZ0ZVGc3GHQMAACTUbNdsyd3vlHRnwCyxKGRHVaZsAQCAQPZYtsxsSFKrr+qZJHf3YpBUbZTvGFNljLIFAADC2GPZcveD/pA8e1PoGFN5fNYr+AAAAPbJwf+NwgOUz1VVHu+MOwYAAEio1JetBZ3jGq52xR0DAAAkVOrLVrEwrqFq
Pu4YAAAgoShbPTWV6gvijgEAABIq9WVrYa9ryHtVq8WdBAAAJFHqy1axsfOK4W2j8QYBAACJRNlaFD0EpZeGYk4CAACSiLK1JNrHVunVcsxJAABAElG2luYkSaXXKjEnAQAASUTZWhbt0HRwK9tsAQCAuZf6srXw0G5JUun1sZiTAACAJApWtsys28x+ZGaPm9lTZvZHocY6EMXDCpKk0rbxmJMAAIAkCnkE5lFJZ7r7sJnlJP3QzO5190cCjrnPikdEOzQtbWdHWwAAYO4FK1vu7pKGGxdzjZOHGm9/LTi8V5JUGpx30QAAQAIE3WbLzLJmtkHSVkn/6u6PtphnjZmtM7N1AwMDIeO0lC10aYGGVBqyto8NAACSL2jZcveau6+WtELSyWZ2TIt5bnT3fnfv7+vrCxmnNTMVbViDQ6n/rgAAAAigLQ3D3XdIul/S2e0Yb18t7BhWqZyNOwYAAEigkN9G7DOzRY3zeUm/IunZUOMdiGKuolIlF3cMAACQQCG/jXi4pFvNLKuo1H3d3b8VcLz9VuwcVWmkEHcMAACQQCG/jfiEpONDLX8uFbvH9NKOJXHHAAAACcRW4ZKKhXGVxlmzBQAA5h5lS9LCnpoG6wvijgEAABKIsiWp2Osa8l7V63EnAQAASUPZklRcGP0c3jYabxAAAJA4lC1JxYXRw1B6eXgvcwIAAOwbypak4pLoS5mlV3bGnAQAACQNZUtNZeu1csxJAABA0lC2JC08pEuSNPga22wBAIC5RdmSVDykW5JU2jYWcxIAAJA0lC1JxcOiHZqWtlVjTgIAAJKGsiWpeHiPJKm0oxZzEgAAkDSULUm9R/RKkko7POYkAAAgaShbkrI93erRsEqluJMAAICkoWxJkpkW2pAGh3k4AADA3KJdNBQ7dqq0syPuGAAAIGEoWw3FXEWlSi7uGAAAIGEoWw3FzhGVRjvjjgEAABKGstVQ7B5Xaaw77hgAACBhKFsNCwtjGhzviTsGAABIGMpWQ7GnrlKdsgUAAOYWZauhuKCuIV+gej3uJAAAIEkoWw3FhZIro51vjMYdBQAAJAhlq6G4KCtJKr08HHMSAACQJJSthuLiRtl6ZWfMSQAAQJIEK1tm9iYzu9/Mnjazp8zsqlBjzYWFS6O9xw++Wok5CQAASJKQx6epSvq0uz9mZr2S1pvZv7r70wHH3G/Fvi5JUmmAbbYAAMDcCbZmy91fcffHGueHJD0jaXmo8Q5U8ZBoh6al18diTgIAAJKkLdtsmdlKScdLerQd4+2P4mEFSVJp23jMSQAAQJIEL1tmtkDSnZKudvdSi+vXmNk6M1s3MDAQOs6MiodHOzQtba/FlgEAACRP0LJlZjlFRes2d/9Gq3nc/UZ373f3/r6+vpBx9qj3iF5JUmnQY8sAAACSJ+S3EU3SzZKecfe/CjXOXOlY0K0eDWtwt3VvAAAA+y/kmq1TJX1c0plmtqFxOjfgeAfGTEUbVmmYXY8BAIC5E2zXD+7+Q0kWavkhFDt2qrQzG3cMAACQIKzGaVLsqKhUycUdAwAAJAhlq0mxa0Slka64YwAAgAShbDVZ2D2q0lh33DEAAECCULaaFPNVDVYLcccAAAAJQtlqUuypqVTriTsGAABIEMpWk2Kvq+S9cvZrCgAA5ghlq0mxKLky2vnGaNxRAABAQlC2mhQXRbsFK72yM+YkAAAgKShbTRYujnZoStkCAABzhbLVpLisU5I0+Eo55iQAACApKFtNJspWaYBttgAAwNygbDUpHhLt0LT0+ljMSQAAQFJQtpoUD81LkkrbxmNOAgAAkoKy1aR4eLRD09L2WsxJAABAUlC2mhRXFCVJgzvYqykAAJgblK0mHQu6VdBOlUpxJwEAAElB2WpmpqINqTTMwwIAAOYGrWKaYras0s5s3DEAAEBCULamKebKKlVycccAAAAJQdmaZmHXiEojnXHHAAAACUHZmqbYNabBsXzcMQAAQEJQtqYpFsZVqlK2AADA
3KBsTVPsqalU64k7BgAASAjK1jTFBa6S98rZrykAAJgDlK1pigulurIq7+Bg1AAA4MAFK1tmdouZbTWzjaHGCGHhIpMklV4aijkJAABIgpBrtv5R0tkBlx9EcXGHJGnwlXLMSQAAQBIEK1vu/qCkN0ItP5TikqhslV6rxJwEAAAkQezbbJnZGjNbZ2brBgYG4o6jYl+XJKm0dSTmJAAAIAliL1vufqO797t7f19fX9xxVDykW5JUep0N5AEAwIGLvWzNN8VDox2alraNx5wEAAAkAWVrmoVHRDs0LW2vxZwEAAAkQchdP9wu6WFJR5nZFjP7RKix5lLvEb2SpMEd7NUUAAAcuI5QC3b3j4Zadki5Yl55lVUqxZ0EAAAkQbCyddAyU9GGVBq23a/bvFl67jnphRekF15Q/fmfq1yqasHXbpQWL25/VgAAMO9RtlooZssq7cxOTXBX9crf00PXrdcTOlZP6Fg9qfO10Y7RiHdpw1/+i47+01+PLzAAAJi32EC+hWKurFIlF11w12tX/LHed90H9V49qN/Vdbpz0SfUdfp/0mW/vUBm0j/cXI83MAAAmLdYs9XCws4RDY5EOzf90Zqb9Gs3fULbsofoy9e5zvuA6YgjsrLGp4wv/eBFfXXj+3TtUz9Vx9FHxZgaAADMR6zZaqHYParSWLduvuBeveem31BuQZceerRDv/XfTMuXa7JoSdJln1qi13SYvvMnj8YXGAAAzFuUrRaK+ao2Vn5Bv3nnOXrvoT/VuueX6PgTWz9U5166REtzg7r1/y6R6nycCAAAdkXZamFZcVSujK55+526d9MqLT0kO+O8nZ3SR3/5Vd1dfp923PNQG1MCAICDAWWrhWv+sFMPX/C/9WdPnqdsd26v81/2h2/RqLr19WtfaEM6AABwMDH3+bOn9P7+fl+3bl3cMfaZu3TM4pe0aHiz/t/QcVI+H3ckAAAQmJmtd/f+vc3Hmq05YCZddlFZD9VO0XM33h93HAAAMI9QtubIxz73VmVU0z9dNxh3FAAAMI9QtubI8jdn9b6Vz+srP3u36q9ujTsOAACYJyhbc+iyTxb0olbqwT/9t7ijAACAeYKyNYc+9Dsr1JsZ1q23d8YdBQAAzBOUrTlUKEgXnrRJa7edoZ0PPR53HAAAMA9QtubY5Z9boWH16g/Oflz+GttuAQCQdpStOXbaeYv0e7/+iq4f+g39ef+/SOVy3JEAAECMKFsB/MU/Ha6PvmeLPrvlCt36npukWi3uSAAAICaUrQAyGekfv7tC7/vF/9AnHvtt3fuRm+KOBAAAYtIRd4Ck6uyU7vzxm3XG27bogrsv1f2f/med/BcXSps2aeDhn+nx+9/QxiddRxwh/fJFfeo7p19atCju2AAAYI5xbMTAXn2ppl/6xQENlbM6KfOYHq8fo5e1fLf5jtXjOnPpEzrz5GH1v7dHh65aosxb3iStWCEtWRIdEwgAAMwbsz02ImWrDZ57oqKPnLVdJtdxbytr9UkdOu6sPh1zygJtenZE3//qy/r+/dIPf75cI/UuSVKXRvQWvaiV2qSV2S3qK+xUR9aV7TB1dEgdOVNHTursyqizy9TZnVFnd0a57qyynVl1dGain13ZyZ8d3R3R+e4OZbtzynZ1KNPZoWzX1KnV5Yn5O7qjaR35nDK5rCxjdEAAQGpRtg5CIyPSow/V9NQjQ9r0TEWbnq/p51ty+vnrC7R9JK+6z89N7Ex1mVwZ1dWhqrJWV1a16KfVlbHouonzEwXNZU1r7CyaL+MySZmMK2OubGNaNhOdz2ai22cyLjObPJ/JRIvKTLs8MY9lJJkp0zhvZpNl0TJqzG9T1zXdxswmFjY5fZd5MxPXKzqvxrIzaoxhjTFMmawpm5UyWVMmOzUtk7UoQ2Yq1+Rj0xhj4vrJeU1yj07S1M9WOjqkbHbqZzY7tfjpt2u6u8pkouvr9V1PkqYe88zU/DP+jtjup4npEzKZqVPzcpuXn2l6CUy/360e
h2x2apkT93lPp1ZjN2eaPt9M97v5ts3z7+nxmf5zNo/ZxHMz8XvgvvexWy1n+nV7y9TK3n4Ppy97+nKbH+vm/HvKvafHdKYcs/m9bb7t3u5PtSqNje16kqZebxOnid/d2f4O8kb24DDbssU2W/NId7f03jOzeu+ZiyTtvv2We/TFxloteoGPj0en6S/0ietr43VVK+OqVsZVG62qNjKu6khVtdGqqiNV1cdr0fTxumpjNdXGaqpXG+fH61Pnq67quEc/q67quFSvubxWl9dd9brLax5lq/pkxujk8rpUq1vjH7WrXpfMXfK6zOtS3eUezVd3qV636GfNVHNT3U21ekY1N9U8M/lPpe4m94mqZ6oro7pMVWVUU1bemB7VN03O0zy9+TT9uonbTD7+M9yu+fo9zecy1ZRt5MxMnp/KnmlcnlrihOb5AKRDxupR+ZJPFbKJ85LMpv5GNL95ctlUAZ/h/MRtJt4AN78Rnm5ivIk3wzbDvKbdm6l7NHVi/D3f36llR2++Gz8n3kSbR3/zG3//6/WJ/wPTMppU6Krp+Vd69vYQt03QsmVmZ0v6G0lZSTe5+7Uhx0s6s6l3SV1ds7lFRlJX45RgE389arWpVS8T55ub354uz3T75ssT802f3mrMVstrddtWq42mr6qYtvpisuTWXOaNqlmPlmW++2297lMFvCZVa1Ktql2WHZXe+q43bRTjyT96jT98lolWqXm1Nlm669W6vBZl2eU+1Gq7LdMbf3ndpzI0F+jJU9NdmfiDWneb+oPuUYWVN/7wN1+WT85f80zLZbpLqtcnlz3xz2Bi3ol/TJNFeGI+qfEH33Zd69G4MFmum+/PxD8Z33VeKfrn0Txp8p/ixD9GTY2zS7l3Nd5eTK1ZNvlu5b25oLd6EzH9ulbztXoT0uof68S06de1etMyfbnTc9eU3WPuvWWZKUfzG6uJsVrlnun89LFzGlenxiZPOY3L5KqqQzVlVVWHxpVr+Sat1Zu/5mzupnotM+ObxJkeiz2dmn9P3Hd93FuZKV+r+VpNa5VhT+Pscv9b/E5kdvvtru+Ws66MclmX9Fstx4pDsLJlZllJ10v6FUlbJP3YzL7p7k+HGhMp1eozpgSzxmm299YUvdBZjZ0CrT7kcayqAAAHp0lEQVT/2tdpu7S+GT5Pm+mztdl8tttq+XvLuLdx93XsmW6/t7Fbfbba0SHlctFX0HO56OQ+9fHDxM/mN1fTP/+dnrfVtFafJ09/U9acsTnnTMuf6XGa6fHZ02MxcX76afp9ndgOYaY8rcbb07gz3cd59jlsyL+/J0v6mbu/IElmdoekD0qibAFACLPdoAnh5XJSPh93CswTIVcFLJe0uenylsY0AACA1Ij9cxczW2Nm68xs3cDAQNxxAAAA5lTIsvWSpDc1XV7RmLYLd7/R3fvdvb+vry9gHAAAgPYLWbZ+LOntZnakmXVKukTSNwOOBwAAMO8E20De3atm9juS7lO064db3P2pUOMBAADMR0G/De7u90i6J+QYAAAA81nsG8gDAAAkGWULAAAgIMoWAABAQJQtAACAgMz3djykNjKzAUkvBh5mmaTXA4+BfcfzMn/x3MxPPC/zF8/N/BTieXmLu+91J6Hzqmy1g5mtc/f+uHNgVzwv8xfPzfzE8zJ/8dzMT3E+L3yMCAAAEBBlCwAAIKA0lq0b4w6Alnhe5i+em/mJ52X+4rmZn2J7XlK3zRYAAEA7pXHNFgAAQNtQtgAAAAJKTdkys7PN7Kdm9jMzuybuPGlmZm8ys/vN7Gkze8rMrmpMX2Jm/2pmzzV+Lo47axqZWdbMfmJm32pcPtLMHm28dv7ZzDrjzphGZrbIzNaa2bNm9oyZvZvXTPzM7FONv2Mbzex2M+vmNRMPM7vFzLaa2camaS1fIxb5UuM5esLMTgiZLRVly8yykq6XdI6kd0r6qJm9M95UqVaV9Gl3f6ekUyRd0Xg+rpH0PXd/u6TvNS6j/a6S9EzT5T+X
9Nfu/jZJ2yV9IpZU+BtJ33b3d0g6TtFzxGsmRma2XNKVkvrd/RhJWUmXiNdMXP5R0tnTps30GjlH0tsbpzWSbggZLBVlS9LJkn7m7i+4+5ikOyR9MOZMqeXur7j7Y43zQ4r+aSxX9Jzc2pjtVkkfiidhepnZCknvl3RT47JJOlPS2sYsPC8xMLOFkk6XdLMkufuYu+8Qr5n5oENS3sw6JBUkvSJeM7Fw9wclvTFt8kyvkQ9K+opHHpG0yMwOD5UtLWVruaTNTZe3NKYhZma2UtLxkh6VdKi7v9K46lVJh8YUK82+KOn3JdUbl5dK2uHu1cZlXjvxOFLSgKR/aHzEe5OZ9YjXTKzc/SVJfynpPxSVrEFJ68VrZj6Z6TXS1l6QlrKFecjMFki6U9LV7l5qvs6jfZKwX5I2MrPzJG119/VxZ8FuOiSdIOkGdz9e0k5N+8iQ10z7Nbb/+aCiMnyEpB7t/jEW5ok4XyNpKVsvSXpT0+UVjWmIiZnlFBWt29z9G43Jr02sxm383BpXvpQ6VdL5ZrZJ0UftZyraTmhR4yMSiddOXLZI2uLujzYur1VUvnjNxOt9kn7u7gPuPi7pG4peR7xm5o+ZXiNt7QVpKVs/lvT2xjdEOhVtwPjNmDOlVmM7oJslPePuf9V01TclXdY4f5mku9udLc3c/bPuvsLdVyp6jXzf3T8m6X5JFzRm43mJgbu/KmmzmR3VmHSWpKfFayZu/yHpFDMrNP6uTTwvvGbmj5leI9+U9BuNbyWeImmw6ePGOZeaPcib2bmKtkfJSrrF3f805kipZWanSfo3SU9qatug/6Fou62vS3qzpBclXeTu0zd2RBuY2RmSPuPu55nZWxWt6Voi6SeSLnX30TjzpZGZrVb0xYVOSS9IulzRG2ZeMzEysz+SdLGib1n/RNJvKtr2h9dMm5nZ7ZLOkLRM0muS/pek/6MWr5FGOb5O0ce+ZUmXu/u6YNnSUrYAAADikJaPEQEAAGJB2QIAAAiIsgUAABAQZQsAACAgyhYAAEBAlC0AqWRmZ5jZt+LOASD5KFsAAAABUbYAzGtmdqmZ/cjMNpjZ35lZ1syGzeyvzewpM/uemfU15l1tZo+Y2RNmdlfj2HUys7eZ2XfN7HEze8zMfqGx+AVmttbMnjWz2xo7OpSZXWtmTzeW85cx3XUACUHZAjBvmdkqRXvnPtXdV0uqSfqYogP+rnP3oyX9QNGeoiXpK5L+wN2PVXSEgonpt0m63t2Pk/RLkiYOy3G8pKslvVPSWyWdamZLJX1Y0tGN5Xwh7L0EkHSULQDz2VmSTpT0YzPb0Lj8VkWHefrnxjxflXSamS2UtMjdf9CYfquk082sV9Jyd79Lktx9xN3LjXl+5O5b3L0uaYOklZIGJY1IutnMfk3RoTwAYL9RtgDMZybpVndf3Tgd5e6fbzHf/h53rPl4dTVJHe5elXSypLWSzpP07f1cNgBIomwBmN++J+kCMztEksxsiZm9RdHfrgsa8/y6pB+6+6Ck7Wb2nsb0j0v6gbsPSdpiZh9qLKPLzAozDWhmCyQtdPd7JH1K0nEh7hiA9OiIOwAAzMTdnzazz0n6jpllJI1LukLSTkknN67bqmi7Lkm6TNKXG2XqBUmXN6Z/XNLfmdkfN5Zx4R6G7ZV0t5l1K1qz9ntzfLcApIy57+/adwCIh5kNu/uCuHMAwGzwMSIAAEBArNkCAAAIiDVbAAAAAVG2AAAAAqJsAQAABETZAgAACIiyBQAAEND/BwhhKub7OajrAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 720x360 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Tunable hyperparameters\n",
    "batch_size = 128\n",
    "epochs = 100\n",
    "lr = 0.01\n",
    "wd = 0\n",
    "use_gpu = False\n",
    "\n",
    "net = get_model()\n",
    "train_model(net, train_features, train_labels, valid_features, valid_labels, epochs, \n",
    "            batch_size, lr, wd, use_gpu)"
   ]
  },
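  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`train_model` lives in `utils.py`, and its internals are not shown here. The sketch below gives a rough, assumption-labeled picture of what the hyperparameters control. It trains a linear model like the one above with plain minibatch gradient descent in NumPy. Each update uses `batch_size` samples, `epochs` is the number of passes over the data, `lr` is the step size, and `wd` is an L2 weight-decay coefficient. The optimizer actually used in `utils.py` may differ. The data are synthetic.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def train_linear_sgd(X, y, epochs=100, batch_size=128, lr=0.01, wd=0.0, seed=0):\n",
    "    # Minibatch gradient descent on the MSE of a linear model, with an L2\n",
    "    # weight-decay term wd * w added to the weight gradient (illustrative only).\n",
    "    rng = np.random.default_rng(seed)\n",
    "    n, d = X.shape\n",
    "    w = np.zeros((d, 1))\n",
    "    b = 0.0\n",
    "    for _ in range(epochs):                  # one epoch = one pass over the data\n",
    "        order = rng.permutation(n)           # reshuffle every epoch\n",
    "        for start in range(0, n, batch_size):\n",
    "            idx = order[start:start + batch_size]\n",
    "            err = X[idx] @ w + b - y[idx]    # residuals, shape (batch, 1)\n",
    "            grad_w = 2 * X[idx].T @ err / len(idx) + wd * w\n",
    "            grad_b = 2 * float(err.mean())\n",
    "            w -= lr * grad_w\n",
    "            b -= lr * grad_b\n",
    "    return w, b\n",
    "\n",
    "def rmse(X, y, w, b):\n",
    "    return float(np.sqrt(np.mean((X @ w + b - y) ** 2)))\n",
    "\n",
    "# Synthetic data standing in for (features, log price).\n",
    "rng = np.random.default_rng(1)\n",
    "X = rng.normal(size=(500, 5))\n",
    "y = X @ rng.normal(size=(5, 1)) + 12.0 + 0.01 * rng.normal(size=(500, 1))\n",
    "\n",
    "w, b = train_linear_sgd(X, y, epochs=200, batch_size=32, lr=0.05)\n",
    "w_wd, b_wd = train_linear_sgd(X, y, epochs=200, batch_size=32, lr=0.05, wd=1.0)\n",
    "print(rmse(X, y, w, b), rmse(X, y, w_wd, b_wd))\n",
    "print(np.linalg.norm(w), np.linalg.norm(w_wd))  # weight decay shrinks the weights\n",
    "```"
   ]
  },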
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the training procedure in place, what remains is to keep trying different hyperparameters and keep the model that performs best."
   ]
  },
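  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One simple way to organize this trial and error is a grid search. Train once per candidate value, score each run on the validation set, and keep the best. To keep the sketch fast and self-contained, it uses synthetic data and swaps in closed-form ridge regression for `train_model`. Ridge regression is a linear model with an L2 penalty, so `wd` plays the role of weight decay. Because the validation set is used to pick the winner, its score ends up slightly optimistic.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def ridge_fit(X, y, wd):\n",
    "    # Closed-form ridge regression with an unpenalized bias column.\n",
    "    Xb = np.hstack([X, np.ones((X.shape[0], 1))])\n",
    "    penalty = wd * np.eye(Xb.shape[1])\n",
    "    penalty[-1, -1] = 0.0                  # do not penalize the bias\n",
    "    return np.linalg.solve(Xb.T @ Xb + penalty, Xb.T @ y)\n",
    "\n",
    "def ridge_rmse(X, y, coef):\n",
    "    Xb = np.hstack([X, np.ones((X.shape[0], 1))])\n",
    "    return float(np.sqrt(np.mean((Xb @ coef - y) ** 2)))\n",
    "\n",
    "# Synthetic data: many noisy features and few samples, so regularization matters.\n",
    "rng = np.random.default_rng(0)\n",
    "X = rng.normal(size=(120, 80))\n",
    "y = X[:, :5] @ rng.normal(size=(5, 1)) + 0.5 * rng.normal(size=(120, 1))\n",
    "X_tr, y_tr, X_va, y_va = X[:90], y[:90], X[90:], y[90:]\n",
    "\n",
    "results = {}\n",
    "for wd in [0.0, 0.1, 1.0, 10.0, 100.0]:\n",
    "    coef = ridge_fit(X_tr, y_tr, wd)\n",
    "    results[wd] = ridge_rmse(X_va, y_va, coef)  # validation RMSE\n",
    "\n",
    "best_wd = min(results, key=results.get)\n",
    "print(results)\n",
    "print('best wd:', best_wd)\n",
    "```"
   ]
  },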
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 10, train rmse: 0.178\n",
      "\n",
      "epoch: 20, train rmse: 0.137\n",
      "\n",
      "epoch: 30, train rmse: 0.120\n",
      "\n",
      "epoch: 40, train rmse: 0.111\n",
      "\n",
      "epoch: 50, train rmse: 0.106\n",
      "\n",
      "epoch: 60, train rmse: 0.103\n",
      "\n",
      "epoch: 70, train rmse: 0.101\n",
      "\n",
      "epoch: 80, train rmse: 0.099\n",
      "\n",
      "epoch: 90, train rmse: 0.098\n",
      "\n",
      "epoch: 100, train rmse: 0.098\n",
      "\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAlsAAAFACAYAAACLPLm0AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAHn5JREFUeJzt3X2UXHWd5/H3N51OqvNAN4SWcRMxcQdZ1JEgkYMD47A+DQ+O4IJPKyquezJzjq7o6s7qrmd33XE8OjszKjOog8KIDqIuyMPxID6j66BokKgRUARF4lFpkECAhCTku3/cKlNpqpNOUr9b3X3fr3Puqapbt+79Vldu59O/369+NzITSZIklTFv0AVIkiTNZYYtSZKkggxbkiRJBRm2JEmSCjJsSZIkFWTYkiRJKsiwJUmSVJBhS5IkqSDDliRJUkHzB11At0MPPTRXrlw56DIkSZL26sYbb7wnM8f3tt2MClsrV65k3bp1gy5DkiRpryLizulsZzeiJElSQYYtSZKkggxbkiRJBc2oMVuSJGl22L59Oxs3bmTr1q2DLqW4VqvFihUrGB4e3q/XG7YkSdI+27hxI0uXLmXlypVExKDLKSYzuffee9m4cSOrVq3ar33YjShJkvbZ1q1bWbZs2ZwOWgARwbJlyw6oBc+wJUmS9stcD1odB/o+DVuSJEkFGbYkSdKss2nTJj74wQ/u8+tOPfVUNm3aVKCiqTUrbF1+OXzhC4OuQpIkHaCpwtaOHTv2+LprrrmGsbGxUmX11KxvI77rXfCEJ8Cf/MmgK5EkSQfgbW97G7fffjurV69meHiYVqvFwQcfzK233spPfvITzjjjDO666y62bt3Kueeey9q1a4FdlwZ88MEHOeWUUzjxxBO5/vrrWb58OVdddRUjIyN9r7VZYavVgi1bBl2FJElzy5veBOvX93efq1fD+98/5dPvec972LBhA+vXr+e6667jtNNOY8OGDb+bnuGiiy7ikEMOYcuWLTzzmc/kzDPPZNmyZbvt47bbbuPSSy/lIx/5CC996Uu5/PLLOfvss/v7Pmha2BoZgQZMviZJUtMcd9xxu82Ddd5553HFFVcAcNddd3Hbbbc9JmytWrWK1atXA3Dsscfy85//vEhtxcJWRBwJfLpr1ZOA/5GZU8fU0kZG4IEHBnZ4SZLmpD20QNVl8eLFv7t/3XXX8eUvf5lvfetbLFq0iJNOOqnnPFkLFy783f2hoSG2FOr9Kha2MvPHwGqAiBgCfglcUep40zIyYjeiJElzwNKlS9m8eXPP5+6//34OPvhgFi1axK233sq3v/3tmqvbXV3diM8Fbs/MO2s6Xm+tlt2IkiTNAcuWLeOEE07gaU97GiMjIxx22GG/e+7kk0/mwx/+MEcddRRHHnkkxx9//AArrS9svRy4tNcTEbEWWAtw+OGHl63Cli1JkuaMT37ykz3XL1y4kM9//vM9n+uMyzr00EPZsGHD79a/9a1v7Xt9HcXn2YqIBcCLgP/b6/nMvCAz12TmmvHx8bLFGLYkSVLN6pjU9BTge5n5mxqOtWd2I0qSpJrVEbZewRRdiLXrTP2QOehKJEma9bIh/58e6PssGrYiYjHwfOCzJY8zbZ1ZYW3dkiTpgLRaLe699945H7gyk3vvvZdWq7Xf+yg6QD4zHwKW7XXDunR+UFu27ApekiRpn61YsYKNGzcyMTEx6FKKa7VarFixYr9f37wZ5MGWLUmSDtDw8PBuM7ZranWM2Zo5OmHLbyRKkqSaNCtsdXcjSpIk1aBZYctuREmSVLNmhi1btiRJUk0MW5IkSQU1K2x1xmzZjShJkmrSrLBly5YkSaqZYUuSJKmgZoUtp36QJEk1a1bYcuoHSZJUs2aGLVu2JElSTZoVthYsgAjDliRJqk2zwlZENW7LbkRJklSTZoUtqLoSbdmSJEk1MWxJkiQV1LywZTeiJEmqUfPC
li1bkiSpRoYtSZKkgpoXtlotw5YkSapN88LWyIhjtiRJUm2aGbZs2ZIkSTVpXtiyG1GSJNWoeWHLbkRJklSjZoYtW7YkSVJNDFuSJEkFFQ1bETEWEZdFxK0RcUtEPKvk8abFGeQlSVKN5hfe/weAazPzrIhYACwqfLy9GxmBRx+F7dtheHjQ1UiSpDmuWNiKiFHg2cA5AJm5DdhW6njTNjJS3W7ZYtiSJEnFlexGXAVMAP8UETdFxEcjYvHkjSJibUSsi4h1ExMTBctpa7WqW8dtSZKkGpQMW/OBZwAfysxjgIeAt03eKDMvyMw1mblmfHy8YDltnZYtx21JkqQalAxbG4GNmXlD+/FlVOFrsLq7ESVJkgorFrYy89fAXRFxZHvVc4GbSx1v2uxGlCRJNSr9bcT/BFzS/ibiHcBrCx9v7+xGlCRJNSoatjJzPbCm5DH2md2IkiSpRs2cQR4MW5IkqRbNC1udMVt2I0qSpBo0L2zZsiVJkmpk2JIkSSqoeWHLqR8kSVKNmhe2nPpBkiTVqHlhy5YtSZJUo+aFraEhGB42bEmSpFo0L2xB1ZVoN6IkSapBc8OWLVuSJKkGhi1JkqSCmhm2Wi27ESVJUi2aGbZs2ZIkSTUxbEmSJBXUzLDVahm2JElSLZoZtpz6QZIk1aS5YcuWLUmSVINmhi27ESVJUk2aGbbsRpQkSTVpbtiyZUuSJNWgmWHLbkRJklSTZoatkRHYtg127hx0JZIkaY5rbtgCx21JkqTimh227EqUJEmFNTNstVrVrS1bkiSpsGaGLVu2JElSTeaX3HlE/BzYDDwK7MjMNSWPN22GLUmSVJOiYavt32bmPTUcZ/o63YiGLUmSVFizuxEdsyVJkgorHbYS+GJE3BgRa3ttEBFrI2JdRKybmJgoXE6b3YiSJKkmpcPWiZn5DOAU4PUR8ezJG2TmBZm5JjPXjI+PFy6nzW5ESZJUk6JhKzN/2b69G7gCOK7k8abNbkRJklSTYmErIhZHxNLOfeAFwIZSx9sndiNKkqSalPw24mHAFRHROc4nM/PagsebPsOWJEmqSbGwlZl3AEeX2v8BcQZ5SZJUk2ZP/WDLliRJKqyZYWt4GObNM2xJkqTimhm2IqquRMOWJEkqrJlhC6quRMdsSZKkwpodtmzZkiRJhTU3bNmNKEmSatDcsGU3oiRJqkGzw5YtW5IkqTDDliRJUkHNDVutlt2IkiSpuOaGLVu2JElSDQxbkiRJBTU3bDn1gyRJqkFzw5ZTP0iSpBo0O2zZsiVJkgprbtjqdCNmDroSSZI0hzU3bI2MVEFr+/ZBVyJJkuawZoctsCtRkiQVZdgybEmSpIKaG7ZarerWbyRKkqSCmhu2bNmSJEk1MGwZtiRJUkHNDVudbkTDliRJKqi5YavTsuWYLUmSVJBhy5YtSZJUUHPDlt2IkiSpBsXDVkQMRcRNEfG50sfaJ3YjSpKkGtTRsnUucEsNx9k3diNKkqQaFA1bEbECOA34aMnj7Be7ESVJUg1Kt2y9H/gLYOdUG0TE2ohYFxHrJiYmCpfTxW5ESZJUg2JhKyJeCNydmTfuabvMvCAz12TmmvHx8VLlPJYtW5IkqQYlW7ZOAF4UET8HPgU8JyL+ueDx9s28ebBwoWFLkiQVVSxsZebbM3NFZq4EXg58NTPPLnW8/dJqGbYkSVJRzZ1nC6pxW47ZkiRJBc2v4yCZeR1wXR3H2icjI7ZsSZKkoqbVshUR50bEQVG5MCK+FxEvKF1ccXYjSpKkwqbbjfgfMvMB4AXAwcCrgPcUq6oudiNKkqTCphu2on17KvCJzPxR17rZy25ESZJU2HTD1o0R8UWqsPWFiFjKHiYqnTXsRpQkSYVNd4D864DVwB2Z+XBEHAK8tlxZNRkZgd/+dtBVSJKkOWy6LVvPAn6cmZsi4mzgHcD95cqqid2IkiSpsOmGrQ8BD0fE0cBbgNuBjxerqi6GLUmSVNh0w9aOzEzgdOAfMvN8
YGm5smrimC1JklTYdMdsbY6It1NN+fBHETEPGC5XVk2c+kGSJBU23ZatlwGPUM239WtgBfB/ilVVF7sRJUlSYdMKW+2AdQkwGhEvBLZm5uwfs9Vqwfbt8Oijg65EkiTNUdO9XM9Lge8ALwFeCtwQEWeVLKwWIyPVrV2JkiSpkOmO2frvwDMz826AiBgHvgxcVqqwWnTC1pYtsHjxYGuRJElz0nTHbM3rBK22e/fhtTNXq1XdOm5LkiQVMt2WrWsj4gvApe3HLwOuKVNSjexGlCRJhU0rbGXmf4mIM4ET2qsuyMwrypVVk+5uREmSpAKm27JFZl4OXF6wlvoZtiRJUmF7DFsRsRnIXk8BmZkHFamqLo7ZkiRJhe0xbGXm7L8kz544ZkuSJBU2+79ReCDsRpQkSYU1O2zZjShJkgprdtiyG1GSJBVm2AJbtiRJUjHNDlt2I0qSpMKaHbbsRpQkSYU1O2wND8PQkC1bkiSpmGJhKyJaEfGdiPh+RPwoIt5Z6lgHZGTEsCVJkoqZ9uV69sMjwHMy88GIGAa+GRGfz8xvFzzmvmu1DFuSJKmYYmErMxN4sP1wuL30uvTPYI2MOGZLkiQVU3TMVkQMRcR64G7gS5l5Q49t1kbEuohYNzExUbKc3uxGlCRJBRUNW5n5aGauBlYAx0XE03psc0FmrsnMNePj4yXL6c1uREmSVFAt30bMzE3A14CT6zjePrEbUZIkFVTy24jjETHWvj8CPB+4tdTx9pvdiJIkqaCS30Z8PHBxRAxRhbrPZObnCh5v/7RacM89g65CkiTNUSW/jfgD4JhS+++bxYvhzjsHXYUkSZqjmj2DPMDYGNx//6CrkCRJc5Rha3QUNm0adBWSJGmOMmyNjcHDD8P27YOuRJIkzUGGrdHR6tauREmSVIBhy7AlSZIKMmyNjVW3hi1JklSAYavTsuUgeUmSVIBhy25ESZJUkGGr041oy5YkSSrAsGXLliRJKsiwddBB1a1hS5IkFWDYmj8fliyxG1GSJBVh2IKqK9GWLUmSVIBhC7wYtSRJKsawBV6MWpIkFWPYArsRJUlSMYYtqLoRbdmSJEkFGLbAli1JklSMYQt2DZDPHHQlkiRpjjFsQdWytX07bNky6EokSdIcY9gCL9kjSZKKMWzBrotRG7YkSVKfGbZgV8uW30iUJEl9ZtgCuxElSVIxhi3Y1Y1oy5YkSeozwxbYsiVJkoopFrYi4gkR8bWIuDkifhQR55Y61gFzgLwkSSpkfsF97wDekpnfi4ilwI0R8aXMvLngMffP4sUwNGQ3oiRJ6rtiLVuZ+avM/F77/mbgFmB5qeMdkAg46CBbtiRJUt/VMmYrIlYCxwA39HhubUSsi4h1ExMTdZTTW+eSPZIkSX1UPGxFxBLgcuBNmfnA5Ocz84LMXJOZa8bHx0uXM7XRUbsRJUlS3xUNWxExTBW0LsnMz5Y81gEbHbVlS5Ik9V3JbyMGcCFwS2b+Xanj9M3YmC1bkiSp70q2bJ0AvAp4TkSsby+nFjzegbFlS5IkFVBs6ofM/CYQpfbfdw6QlyRJBTiDfMfoKDzwAOzcOehKJEnSHGLY6hgdhUzYvHnQlUiSpDnEsNXhJXskSVIBhq2OzsWo/UaiJEnqI8NWRyds2bIlSZL6yLDV0elGtGVLkiT1kWGrw5YtSZJUgGGrwwHykiSpAMNWhwPkJUlSAYatjoULq8WWLUmS1EeGrW5eskeSJPWZYavb6KjdiJIkqa8MW91GR23ZkiRJfWXY6jY2ZsuWJEnqK8NWN1u2JElSnxm2ujlAXpIk9Zlhq5sD5CVJUp8ZtrqNjsKWLbB9+6ArkSRJc4Rhq5uX7JEkSX1m2OrmJXskSVKfGba6dcKWLVuSJKlPDFvdOt2ItmxJkqQ+MWx1s2VLkiT1mWGrmwPkJUlSnxm2ujlAXpIk9Zlhq9tBB1W3tmxJkqQ+KRa2IuKiiLg7IjaUOkbfDQ3B0qWGLUmS1DclW7Y+Bpxc
cP9leMkeSZLUR8XCVmZ+A/htqf0XMzpqy5YkSeqbgY/Zioi1EbEuItZNTEwMupzqG4m2bEmSpD4ZeNjKzAsyc01mrhkfHx90ObZsSZKkvhp42JpxxsYMW5IkqW8MW5NNNUD+hz+Ec8+Fbdvqr0mSJM1aJad+uBT4FnBkRGyMiNeVOlZfdboRM3df/773wXnnwfvfP5i6JEnSrFTy24ivyMzHZ+ZwZq7IzAtLHauvxsZgxw7YsmXXuh074OqrIQLe+U64667B1SdJkmYVuxEn63XJnuuvh3vvhfe+t2rxevObB1ObJEmadQxbk3XCVvcg+SuvhAUL4M//HN7xDrj8crj22sHUJ0mSZhXD1mRjY9Vtp2Urswpbz3tedSmft7wFnvxkeMMbYOvWwdUpSZJmBcPWZJNbtn74Q/jZz+DFL64eL1wI558Pt98Of/3Xg6lRkiTNGoatyTotW52wdeWV1cD4P/3TXds873nwspfBu99dhS5JkqQpGLYmmzxA/sor4Q//EA47bPft/vZvYXgY3vjGx04TIUmS1GbYmqy7G/HOO+Gmm+CMMx673fLl1TQQ11wD//Iv9dYoSZJmDcPWZIsXw9BQFbauuqpad/rpvbf9sz+DJUvgYx+rrTxJkjS7GLYmi9h1yZ4rr4SnPhWOOKL3tosXw0teAp/5DDz0UL11SpKkWcGw1cvoKNxxB3zjG727ELudcw5s3gxXXFFLaZIkaXYxbPUyNgZf+hI8+ujew9aJJ8KqVXYlSpKkngxbvYyOVkFr+XI49tg9bztvHrzmNfDVr8IvflFPfZIkadYwbPXSmWvrjDOqMVx78+pXV9M/fOITZeuSJEmzjmGrl870D3vrQuxYtQpOOqnqSnTOLUmS1MWw1cuqVfB7vwd//MfTf80558BPfwrXX1+sLEmSNPsYtnp5+9vh5purGeKn68wzq6kgLr64XF2SJGnWMWz1smABHHzwvr1myRI46yz49Kfh4YfL1CVJkmYdw1Y/nXMOPPBANRmqJEkShq3+evazYeVK59ySJEm/M3/QBcwp8+ZV00D85V/CO94By5bBIYdUXZKHHAJPfjI87nGDrlKSJNXIsNVvr3sdXHghvPvdvaeBOOwwePrT4eijd90eddS+DcaXJEmzhmGr3w4/HDZuhJ074f774b77quWee+CWW+D734cf/AD+/u/hkUeq1yxcCH/wB3DMMdVy+OHVNxsXLdp1u2RJNf/XggWDfX+SJGmfRM6gSTjXrFmT69atG3QZ9dixA378Y1i/Hm66addy3317ft3ChXDQQdWydCmMjOy+LFoErVa1Xau1+/2FC3ctrVYV3LrXLVxYretehod33XaWeQ71kyQpIm7MzDV7286WrUGZPx+e+tRqeeUrq3WZ1fUVf/MbeOihagqJzu3mzdXywAO7L1u2VMumTdV2W7ZULWZbt+5a+m1oaPcQNjmMzZ//2NvJ9+fN233p3mf3fjvbdy9DQ9XSfX+qpbPvvT3uvp3q/lTP72mJ2P3+dC7/JEmaUwxbM0kEPPGJ1dIvmbBtWxXAupetW6vbyc9t2wbbt1e3nee2b9996WzTfb9zu2PH7rfbt1cX9X7kkSo4dtZnVl2tO3dWzz/66GP3t21btX7Hjv79PGaC7hDWHcZ6re8EtO4Fdt3fWwicvJ+p9rmv66e7dNe6p+emut/9ePL7nnx/8nZT3e6tno491dpr39M57nRqnHy/21R1T95mqnVTvZepXjtVz8dU63v9gTFdvX7Wmbt+V3Tud47TvXSe6156bbenz7pXPfu7bl/t6d9CR/f76v75T/ec6Ufd3a/p/P7uLN0/86l+h3T/m+j1XibXvb+1QfW78LTT9m0fBRm25rqIXV2Es1Xnl20nqHXCWSeIdT/uXrqD3N4edwe/7tvJgTDzsdt3B8de6zqv6f6Po/u5yesmr5/8H0jndvJrumucan+99rmn9d117+vSqXVvz011f/L7narenTsfu12v287PbE/1
dP+bm6q+qfa9p+PPoOEaUiOMjMyoCcaLhq2IOBn4ADAEfDQz31PyeJqjInZ1+83m0Ch17CkQ7qk1qVdI7LXfvR1ruq+FPbey9aqvOwTv3Dm9Foqpwnav1tXONt1/ZEzV4tfrD4/J73Vff47TWbcvrTJ7+rcweZ+TW66m+wdKP+qe/DPr1Yo+1R9uk/+AnKolb3Ld0/059no//Whx7KNiYSsihoDzgecDG4HvRsTVmXlzqWNK0qywv10lkmalkl8rOw74aWbekZnbgE8Bpxc8niRJ0oxTMmwtB+7qeryxvW43EbE2ItZFxLqJiYmC5UiSJNVv4BMmZeYFmbkmM9eMj48PuhxJkqS+Khm2fgk8oevxivY6SZKkxigZtr4LHBERqyJiAfBy4OqCx5MkSZpxin0bMTN3RMQbgC9QTf1wUWb+qNTxJEmSZqKi82xl5jXANSWPIUmSNJMNfIC8JEnSXGbYkiRJKsiwJUmSVFDkDLpAakRMAHcWPsyhwD2Fj6F95+cyc/nZzEx+LjOXn83MVOJzeWJm7nWS0BkVtuoQEesyc82g69Du/FxmLj+bmcnPZebys5mZBvm52I0oSZJUkGFLkiSpoCaGrQsGXYB68nOZufxsZiY/l5nLz2ZmGtjn0rgxW5IkSXVqYsuWJElSbQxbkiRJBTUmbEXEyRHx44j4aUS8bdD1NFlEPCEivhYRN0fEjyLi3Pb6QyLiSxFxW/v24EHX2kQRMRQRN0XE59qPV0XEDe1z59MRsWDQNTZRRIxFxGURcWtE3BIRz/KcGbyIeHP799iGiLg0IlqeM4MRERdFxN0RsaFrXc9zJCrntT+jH0TEM0rW1oiwFRFDwPnAKcBTgFdExFMGW1Wj7QDekplPAY4HXt/+PN4GfCUzjwC+0n6s+p0L3NL1+L3A+zLz94H7gNcNpCp9ALg2M/8NcDTVZ+Q5M0ARsRx4I7AmM58GDAEvx3NmUD4GnDxp3VTnyCnAEe1lLfChkoU1ImwBxwE/zcw7MnMb8Cng9AHX1FiZ+avM/F77/maq/zSWU30mF7c3uxg4YzAVNldErABOAz7afhzAc4DL2pv4uQxARIwCzwYuBMjMbZm5Cc+ZmWA+MBIR84FFwK/wnBmIzPwG8NtJq6c6R04HPp6VbwNjEfH4UrU1JWwtB+7qeryxvU4DFhErgWOAG4DDMvNX7ad+DRw2oLKa7P3AXwA724+XAZsyc0f7sefOYKwCJoB/anfxfjQiFuM5M1CZ+Uvgb4BfUIWs+4Eb8ZyZSaY6R2rNBU0JW5qBImIJcDnwpsx8oPu5rOYkcV6SGkXEC4G7M/PGQdeix5gPPAP4UGYeAzzEpC5Dz5n6tcf/nE4Vhv8VsJjHdmNphhjkOdKUsPVL4Aldj1e012lAImKYKmhdkpmfba/+TacZt31796Dqa6gTgBdFxM+putqfQzVOaKzdRQKeO4OyEdiYmTe0H19GFb48ZwbrecDPMnMiM7cDn6U6jzxnZo6pzpFac0FTwtZ3gSPa3xBZQDWA8eoB19RY7XFAFwK3ZObfdT11NfCa9v3XAFfVXVuTZebbM3NFZq6kOke+mpmvBL4GnNXezM9lADLz18BdEXFke9VzgZvxnBm0XwDHR8Si9u+1zufiOTNzTHWOXA28uv2txOOB+7u6G/uuMTPIR8SpVONRhoCLMvOvBlxSY0XEicD/A37IrrFB/41q3NZngMOBO4GXZubkwY6qQUScBLw1M18YEU+iauk6BLgJODszHxlkfU0UEaupvriwALgDeC3VH8yeMwMUEe8EXkb1LeubgP9INfbHc6ZmEXEpcBJwKPAb4H8CV9LjHGmH43+g6vZ9GHhtZq4rVltTwpYkSdIgNKUbUZIkaSAMW5IkSQUZtiRJkgoybEmSJBVk2JIkSSrIsCWpkSLipIj43KDrkDT3GbYkSZIKMmxJmtEi4uyI+E5ErI+If4yIoYh4MCLeFxE/ioivRMR4e9vVEfHtiPhBRFzRvnYd
EfH7EfHliPh+RHwvIv51e/dLIuKyiLg1Ii5pT3RIRLwnIm5u7+dvBvTWJc0Rhi1JM1ZEHEU1O/cJmbkaeBR4JdUFf9dl5lOBr1PNFA3wceC/ZubTqa5Q0Fl/CXB+Zh4N/CHQuSzHMcCbgKcATwJOiIhlwIuBp7b3866y71LSXGfYkjSTPRc4FvhuRKxvP34S1WWePt3e5p+BEyNiFBjLzK+3118MPDsilgLLM/MKgMzcmpkPt7f5TmZuzMydwHpgJXA/sBW4MCL+HdWlPCRpvxm2JM1kAVycmavby5GZ+b96bLe/1x3rvl7do8D8zNwBHAdcBrwQuHY/9y1JgGFL0sz2FeCsiHgcQEQcEhFPpPrddVZ7m38PfDMz7wfui4g/aq9/FfD1zNwMbIyIM9r7WBgRi6Y6YEQsAUYz8xrgzcDRJd6YpOaYP+gCJGkqmXlzRLwD+GJEzAO2A68HHgKOaz93N9W4LoDXAB9uh6k7gNe2178K+MeI+N/tfbxkD4ddClwVES2qlrX/3Oe3JalhInN/W98laTAi4sHMXDLoOiRpOuxGlCRJKsiWLUmSpIJs2ZIkSSrIsCVJklSQYUuSJKkgw5YkSVJBhi1JkqSC/j+rERwxTvbtOQAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 720x360 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
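    "# Retrain a fresh model on the combined training + validation data; passing None skips validation\n",
    "net = get_model()\n",
    "train_model(net, train_valid_features, train_valid_labels, None, None, epochs,\n",
    "            batch_size, lr, wd, use_gpu)"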
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Run the code below to predict the test set with the trained model. It writes `submission.csv` to the current directory, ready for submission."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "pred(net, test, test_features)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
